#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar  4 13:47:01 2022

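Monte Carlo simulation of the amplitude statistic (max-sliced minus min-sliced
squared 2-Wasserstein distance) between samples from an ellipsoid and from the
unit sphere, compared with a reference normal density.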
"""
import ot
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde

def generate_uniform_sphere(d,n,R):
    """Sample ``n`` points uniformly on the sphere of radius ``R`` in R^d.

    A standard Gaussian vector divided by its Euclidean norm is uniformly
    distributed on the unit sphere; the result is then scaled by ``R``.

    Parameters
    ----------
    d : int
        Ambient dimension.
    n : int
        Number of points to draw.
    R : float
        Radius of the sphere.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n, d)`` whose rows all have Euclidean norm ``R``.
    """
    # One batched draw replaces the per-row Python loop. The (n, d) array is
    # filled in row-major order, i.e. the same sequence of variates as drawing
    # one (1, d) row at a time from the global RNG.
    gaussians = np.random.normal(0, 1, size=(n, d))
    norms = np.linalg.norm(gaussians, axis=1, keepdims=True)
    return R * gaussians / norms


def generate_uniform_ellipsoid(d,n,sigma):
    """Sample ``n`` points on the ellipsoid ``sum_i x_i**2 / sigma_i**2 == 1``.

    Each point equals ``sigma * theta`` (elementwise) with ``theta`` uniform on
    the unit sphere: a Gaussian vector with per-coordinate scale ``sigma`` is
    rescaled onto the ellipsoid surface.

    NOTE: for unequal ``sigma`` this is the image of the uniform sphere
    measure under ``diag(sigma)``, not the surface-area-uniform measure on
    the ellipsoid.

    Parameters
    ----------
    d : int
        Ambient dimension; must equal ``len(sigma)`` when ``sigma`` is a vector.
    n : int
        Number of points to draw.
    sigma : array_like of shape (d,), or scalar
        Semi-axis lengths of the ellipsoid. Lists and scalars are accepted.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n, d)``.
    """
    sigma = np.asarray(sigma)
    # Broadcasting sigma over the (n, d) draw gives column i the scale
    # sigma[i]. Variates are generated in row-major order, matching one
    # d-vector per row.
    gaussians = np.random.normal(0, sigma, size=(n, d))
    # Per-row factor that puts each point exactly on the ellipsoid surface.
    radii = np.sqrt(np.sum(gaussians**2 / sigma**2, axis=1, keepdims=True))
    return gaussians / radii

    


# ---------------------------------------------------------------------------
# Simulation parameters
# ---------------------------------------------------------------------------
d = 3        # ambient dimension
n = 600      # points drawn from each distribution per replicate
ns = 500     # number of Monte Carlo replicates
N = 2000     # random directions used to approximate the "min-sliced" term

# Uniform empirical weights on the n sample points (kept as two arrays).
a, b = np.ones((n,)) / n, np.ones((n,)) / n
# One amplitude estimate per replicate, filled in by the simulation loop.
amp = np.empty((ns,))

# Reference normal density N(0, rvar) evaluated on the plotting grid xs.
# NOTE(review): rvar = 16.43 is hard-coded; presumably the theoretical
# limiting variance of sqrt(n) * (amplitude - 8/3) -- confirm its source.
rvar = 16.43
xs = np.linspace(-15,15,2000)
limSdens = np.exp(-xs**2/(2*rvar))/np.sqrt(2*rvar*np.pi)
for i in range(ns):
    # Fresh samples for this replicate: n points on the ellipsoid with
    # semi-axes (2, 2, 4) and n points on the unit sphere.
    datap = generate_uniform_ellipsoid(d,n,np.array([2,2,4]))
    dataq = generate_uniform_sphere(d,n,1)

    # Max-sliced term: 1000 is passed as POT's n_projections. POT returns a
    # distance (p=2 by default), so it is squared to obtain W2^2.
    # NOTE(review): randint(50) yields only 50 distinct projection seeds, so
    # replicates reuse the same projection sets -- confirm this is intended.
    amp[i] = ot.sliced.max_sliced_wasserstein_distance(
        datap, dataq, a, b, 1000, seed=np.random.randint(50)
    )**2

    # "Min-sliced" term: smallest projected W2^2 (emd2_1d with its default
    # squared-Euclidean metric) over N random unit directions.
    angles = generate_uniform_sphere(d,N,1)
    # Project both samples onto every direction with one matrix product;
    # row j holds the projections onto angles[j].
    proj_p = angles @ datap.T
    proj_q = angles @ dataq.T
    # NOTE(review): the running minimum starts at 1/3 (which appears to be the
    # population minimum for these two distributions) instead of +inf, so
    # this term is capped at 1/3 -- confirm that is intended.
    v0 = 1/3
    for pu, qu in zip(proj_p, proj_q):
        v0 = min(v0, ot.emd2_1d(pu, qu, a, b))
    amp[i] -= v0
# Center at the population amplitude and rescale by sqrt(n).
# 8/3 = 3 - 1/3: with sigma = (2, 2, 4) and the unit sphere in d = 3, the
# projected W2^2 along a unit direction u is (|sigma*u| - 1)**2 / 3, which is
# at most 3 (u = e3) and at least 1/3 (u in the e1-e2 plane).
amp = np.sqrt(n)*(amp - 8/3)

# Kernel density estimate of the rescaled amplitudes; evaluate it on the grid
# once and reuse the values for both the curve and the shaded area.
density = gaussian_kde(amp,'silverman')
emp_dens = density(xs)
plt.plot(xs,emp_dens,color='cadetblue',label='simulated (KDE)')
plt.fill_between(xs,emp_dens,color='paleturquoise',alpha=0.5)
# Reference N(0, rvar) density.
plt.plot(xs,limSdens,color='palevioletred',label='N(0, rvar) reference')
plt.fill_between(xs,limSdens,color='pink',alpha=0.5)
plt.xlabel("x")
plt.ylabel("Density")
plt.title('amplitude')
plt.legend()
# Without this the figure never appears when the file is run as a plain
# script; inline/interactive backends display it as before.
plt.show()


    
        
        
